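// Entry point for the experiment runner: solves and evaluates the example problems
// and writes the results under Results/.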

#include <cstddef>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "Analysis/problem_analysis.h"
#include "Analysis/domain_analysis.h"
#include "Example Problems/City Exploration/city_exploration_policy_generators.h"
#include "Example Problems/City Exploration/city_exploration_domain.h"
#include "Example Problems/City Exploration/city_exploration_problem.h"
#include "Example Problems/Graph Rock Sample/graph_rock_sample_domain.h"
#include "Example Problems/Graph Rock Sample/graph_rock_sample_problem.h"
#include "Policy/BSQ/belief_state_query.h"
#include "Policy/Rule Nodes/parameter_value_inequality.h"
#include "Policy/rule.h"
#include "Policy/policy.h"
#include "interval_solver.h"
#include "full_belief_state.h"
#include "Analysis/policy_analysis.h"
#include "eval_policy_dfs.h"
#include "Example Problems/Lane Merger/lane_merger_policy_generator.h"
#include "Example Problems/Lane Merger/lane_merger_domain.h"
#include "Example Problems/Lane Merger/lane_merger_problem.h"
#include "Example Problems/Spaceship Repair/spaceship_repair_domain.h"
#include "Example Problems/Spaceship Repair/spaceship_repair_problem.h"


//Source of code for detecting memory leaks: https://stackoverflow.com/questions/4790564/finding-memory-leaks-in-a-c-application-with-visual-studio
#include "windows.h"
#define _CRTDBG_MAP_ALLOC
#include <stdlib.h>  
#include <crtdbg.h>

// Identifiers for the example problems this driver can run. Each one selects a
// problem/domain pair and a default policy file.
enum defined_problems {
    LM = 0,  // Lane Merger
    GRS = 1, // Graph Rock Sample
    SR = 2,  // Spaceship Repair
    SV = 3   // City Exploration (registered under the key "store_vist" in main)
};

void run_solver(defined_problems selected_problem, interval_solver_type selected_solver, std::string solutions_file_path, std::string interval_file_path, std::string time_samples_path, unsigned int run_count, unsigned int horizon, unsigned int min_time, unsigned int max_time, unsigned thread_number, unsigned int time_sample_count)
{
    Problem* chosen_problem = NULL;
    Domain* chosen_domain = NULL;
    std::string policy_path = "";

    switch (selected_problem) {
    case LM:
        chosen_problem = new LaneMerger();
        chosen_domain = new LaneMergerDomain();
        policy_path = "Example Problems/Lane Merger/Policies/lane_merger_policy3.txt";
        break;
    case GRS:
        chosen_problem = new GraphRockSample();
        chosen_domain = new GraphRockSampleDomain();
        policy_path = "Example Problems/Graph Rock Sample/Policies/graph_rock_sample_policy1.txt";
        break;
    case SR:
        chosen_problem = new SpaceshipRepair();
        chosen_domain = new SpaceshipRepairDomain();
        policy_path = "Example Problems/Spaceship Repair/Policies/spaceship_repair_policy.txt";
        break;
    case SV:
        chosen_problem = new CityExploration();
        chosen_domain = new CityExplorationDomain();
        policy_path = "Example Problems/City Exploration/Policies/CityExplorationPolicy2.txt";
        break;
    default:

        return;
        break;
    }

    FullBeliefState current_belief(*chosen_domain);
    Policy test_policy(policy_path, *chosen_problem);
    for (int i = 0; i < run_count; ++i) {
        interval_solver(thread_number, horizon, &current_belief, test_policy, *chosen_problem, chosen_domain, solutions_file_path, interval_file_path, 3u, min_time, max_time, selected_solver, true, time_samples_path, time_sample_count);
    }

    if (chosen_problem) {
        delete chosen_problem;
        chosen_problem = NULL;
    }
    if (chosen_domain) {
        delete chosen_domain;
        chosen_domain = NULL;
    }
}


void policy_evaluator(defined_problems selected_problem, std::string solutions_file_path, std::string results_file_path, unsigned int horizon, unsigned int evaluation_count, unsigned int random_seed, unsigned int thread_number)
{
    Problem* chosen_problem = NULL;
    Domain* chosen_domain = NULL;
    std::string policy_path = "";

    switch (selected_problem) {
    case LM:
        chosen_problem = new LaneMerger();
        chosen_domain = new LaneMergerDomain();
        policy_path = "Example Problems/Lane Merger/Policies/lane_merger_policy3.txt";
        break;
    case GRS:
        chosen_problem = new GraphRockSample();
        chosen_domain = new GraphRockSampleDomain();
        policy_path = "Example Problems/Graph Rock Sample/Policies/graph_rock_sample_policy1.txt";
        break;
    case SR:
        chosen_problem = new SpaceshipRepair();
        chosen_domain = new SpaceshipRepairDomain();
        policy_path = "Example Problems/Spaceship Repair/Policies/spaceship_repair_policy.txt";
        break;
    case SV:
        chosen_problem = new CityExploration();
        chosen_domain = new CityExplorationDomain();
        policy_path = "Example Problems/City Exploration/Policies/CityExplorationPolicy2.txt";
        break;
    default:

        return;
        break;
    }

    FullBeliefState current_belief(*chosen_domain);
    Policy test_policy(policy_path, *chosen_problem);
    policy_analysis_parrallel(&current_belief, test_policy, *chosen_problem, solutions_file_path, horizon, 1.0f, evaluation_count, results_file_path, random_seed, thread_number);

    if (chosen_problem) {
        delete chosen_problem;
        chosen_problem = NULL;
    }
    if (chosen_domain) {
        delete chosen_domain;
        chosen_domain = NULL;
    }
}

void rcompliant_policy_evaluator(defined_problems selected_problem, std::string results_file_path, unsigned int horizon, unsigned int evaluation_count, unsigned int random_seed, unsigned int thread_number, unsigned int policy_count)
{
    Problem* chosen_problem = NULL;
    Domain* chosen_domain = NULL;
    std::string policy_path = "";

    switch (selected_problem) {
    case LM:
        chosen_problem = new LaneMerger();
        chosen_domain = new LaneMergerDomain();
        policy_path = "Example Problems/Lane Merger/Policies/lane_merger_policy3.txt";
        break;
    case GRS:
        chosen_problem = new GraphRockSample();
        chosen_domain = new GraphRockSampleDomain();
        policy_path = "Example Problems/Graph Rock Sample/Policies/graph_rock_sample_policy1.txt";
        break;
    case SR:
        chosen_problem = new SpaceshipRepair();
        chosen_domain = new SpaceshipRepairDomain();
        policy_path = "Example Problems/Spaceship Repair/Policies/spaceship_repair_policy.txt";
        break;
    case SV:
        chosen_problem = new CityExploration();
        chosen_domain = new CityExplorationDomain();
        policy_path = "Example Problems/City Exploration/Policies/CityExplorationPolicy2.txt";
        break;
    default:

        return;
        break;
    }

    FullBeliefState current_belief(*chosen_domain);
    Policy test_policy(policy_path, *chosen_problem);
    random_policy_analysis_parrallel(&current_belief, test_policy, *chosen_problem, horizon, 1.0f, evaluation_count, results_file_path, random_seed, thread_number, policy_count);
    if (chosen_problem) {
        delete chosen_problem;
        chosen_problem = NULL;
    }
    if (chosen_domain) {
        delete chosen_domain;
        chosen_domain = NULL;
    }
}

// Worker thread used by test_spaceship_heatmap().
//
// Works on private copies of the problem, domain, belief state and policy. Repeatedly
// takes one parameter vector from *input_list, averages 300 calls of evalPolicyDFS on
// it, and appends (parameters, average) to *output_list.
//
// Synchronisation contract shared with test_spaceship_heatmap():
//   - *input_list and *continue_loop are only accessed while holding *input_lock;
//   - *output_list is only accessed while holding *output_lock.
// The worker returns once it observes *continue_loop == false.
void test_spaceship_worker(Problem* current_problem, Domain* current_domain, BeliefState* current_belief, Policy* current_policy, unsigned int horizon,
    std::mutex* input_lock, std::mutex* output_lock, std::list<std::vector<float>>* input_list, std::list<std::pair<std::vector<float>, float>>* output_list, bool* continue_loop)
{
    const int evaluations_per_point = 300;

    // Owned copies; unique_ptr frees them on every exit path, including exceptions.
    std::unique_ptr<Problem> my_problem(current_problem->create_copy());
    std::unique_ptr<Domain> my_domain(current_domain->create_copy());
    std::unique_ptr<BeliefState> my_belief(current_belief->create_copy());
    Policy my_policy(*current_policy);

    std::vector<float> current_input;
    while (true) {
        bool found_input = false;
        {
            std::lock_guard<std::mutex> guard(*input_lock);
            // The stop flag is read under the same mutex the producer holds when
            // clearing it; an unsynchronised read of a plain bool is a data race.
            if (!*continue_loop) {
                break;
            }
            if (!input_list->empty()) {
                current_input = std::move(input_list->front());
                input_list->pop_front();
                found_input = true;
            }
        }

        if (!found_input) {
            // Queue is empty: give other threads the CPU instead of spinning on the mutex.
            std::this_thread::yield();
            continue;
        }

        float current_average = 0.0f;
        for (int i = 0; i < evaluations_per_point; ++i) {
            current_average += evalPolicyDFS(my_belief.get(), my_policy, *my_problem, current_input, horizon, 1.0f);
        }
        current_average /= static_cast<float>(evaluations_per_point);

        std::lock_guard<std::mutex> guard(*output_lock);
        output_list->emplace_back(current_input, current_average);
    }
}

// Evaluates the spaceship-repair policy on a grid over its two parameters
// (x and y each stepped by 0.002 from 0 while <= 1) using a pool of worker threads.
// Each point's 300-run average is echoed to stdout and written to
// Results/heatmap_spaceship_repair.txt as "x,y,average".
void test_spaceship_heatmap() {
    SpaceshipRepair test_problem;
    SpaceshipRepairDomain test_domain;
    FullBeliefState test_belief(test_domain);
    std::string policy_file_path = "Example Problems/Spaceship Repair/Policies/spaceship_repair_policy.txt";
    std::string solve_path = "Results/heatmap_spaceship_repair.txt";
    const unsigned int thread_number = 10u;
    const unsigned int horizon = 12u;
    Policy test_policy(policy_file_path, test_problem);

    std::list<std::vector<float>> input_list;
    std::list<std::pair<std::vector<float>, float>> output_list;
    std::mutex input_lock, output_lock;
    bool keep_looping = true;  // Guarded by input_lock (see test_spaceship_worker).
    std::list<std::thread> thread_list;

    for (unsigned int i = 0; i < thread_number; ++i) {
        thread_list.emplace_back(test_spaceship_worker, &test_problem, &test_domain, &test_belief, &test_policy, horizon, &input_lock, &output_lock, &input_list, &output_list, &keep_looping);
    }

    // The float accumulation is kept so the sampled points are unchanged. The number
    // of points is counted rather than assumed: rounding in the accumulation decides
    // whether the last step lands on 1.0, and a hard-coded total would make the
    // collector below wait forever for results that were never requested.
    std::size_t pending = 0;
    std::vector<float> parameter_values(2);
    for (float x = 0.0f; x <= 1.0f; x += 0.002) {
        for (float y = 0.0f; y <= 1.0f; y += 0.002) {
            parameter_values[0] = x;
            parameter_values[1] = y;
            std::lock_guard<std::mutex> guard(input_lock);
            input_list.emplace_back(parameter_values);
            ++pending;
        }
    }

    std::ofstream output_file(solve_path);
    if (!output_file) {
        std::cerr << "Could not open " << solve_path << " for writing; heatmap results will only be printed." << std::endl;
    }

    std::vector<float> result_parameters;
    float result_average = 0.0f;
    std::size_t count = pending;
    while (count > 0) {
        bool found_output = false;
        {
            std::lock_guard<std::mutex> guard(output_lock);
            if (!output_list.empty()) {
                result_parameters = std::move(output_list.front().first);
                result_average = output_list.front().second;
                output_list.pop_front();
                found_output = true;
            }
        }
        if (!found_output) {
            std::this_thread::yield();
            continue;
        }
        // Console output is flushed per line so progress stays visible; the file is
        // flushed once on close.
        std::cout << result_parameters[0] << "," << result_parameters[1] << "," << result_average << "," << count << std::endl;
        output_file << std::setprecision(4) << result_parameters[0] << "," << result_parameters[1] << ",";
        output_file << std::setprecision(10) << result_average << '\n';
        --count;
    }
    output_file.close();

    {
        std::lock_guard<std::mutex> guard(input_lock);
        keep_looping = false;
    }
    for (std::thread& worker : thread_list) {
        worker.join();
    }
    thread_list.clear();
}

void run_all_experiments(const std::map<std::string,defined_problems> &problems, const std::map<std::string,interval_solver_type> &solvers) {

    unsigned int evaluation_random_seed = 76421903u;   //Random seed to use for setting the initial states while evaluating.
    unsigned int evaluation_thread_number = 16u;       //Number of worker processes running in parallel while evaluating.
    unsigned int horizon = 100u;                       //The horizon of the problem.
    unsigned int max_solve_time = 1560u;               //Maximum time to run the solver. This only occurs if the hypothesized optimal partition has less than the minimum.
    unsigned int min_solve_time = 1500u;               //Minimum time to run the solver.
    unsigned int rconfident_policy_count = 10u;        //Number of times to evaluate RConfident for each problem.
    unsigned int solution_evaluation_count = 25000u;   //Number of runs to evaluate solutions on.
    unsigned int solver_run_count = 10u;               //Number of times to run the solver on each problem.
    unsigned int solve_thread_number = 8u;             //Number of worker processes running in parallel while running the solver.
    unsigned int time_slice_count = 120u;              //Number of times to record the hypothesized optimal partition while solving.
    unsigned int time_slice_evaluation_count = 10000u; //Number of run to evaluate record hypothesized optimal partitions on.

    std::string solution_file_path, rconfident_eval_path, interval_file_path, time_slice_path, solutions_eval_path, time_slice_eval_path;
    for (std::map<std::string, defined_problems>::const_iterator problem_it = problems.cbegin(); problem_it != problems.cend(); ++problem_it) {
        std::cout << "Current Problem: " << problem_it->first << std::endl;
        rconfident_eval_path = "Results/RConfident/" + problem_it->first + ".txt";
        for (std::map<std::string, interval_solver_type>::const_iterator solver_it = solvers.cbegin(); solver_it != solvers.cend(); ++solver_it) {
            solution_file_path = "Results/" + solver_it->first + "/" + problem_it->first + "_solutions.txt";
            interval_file_path = "Results/" + solver_it->first + "/" + problem_it->first + "_intervals.txt";
            time_slice_path = "Results/" + solver_it->first + "/" + problem_it->first + "_time_slices.txt";
            solutions_eval_path = "Results/" + solver_it->first + "/" + problem_it->first + "_solutions_evaluated.txt";
            time_slice_eval_path = "Results/" + solver_it->first + "/" + problem_it->first + "_time_slices_evaluated.txt";

            run_solver(problem_it->second, solver_it->second, solution_file_path, interval_file_path, time_slice_path, solver_run_count, horizon, min_solve_time, max_solve_time, solve_thread_number, time_slice_count);
            policy_evaluator(problem_it->second, solution_file_path, solutions_eval_path, horizon, solution_evaluation_count, evaluation_random_seed, evaluation_thread_number);
            policy_evaluator(problem_it->second, time_slice_path, time_slice_eval_path, horizon, time_slice_evaluation_count, evaluation_random_seed, evaluation_thread_number);
        }

        rcompliant_policy_evaluator(problem_it->second, rconfident_eval_path, horizon, solution_evaluation_count, evaluation_random_seed, evaluation_thread_number, rconfident_policy_count);
    }

    test_spaceship_heatmap();
}

int main()
{
    _CrtMemState sOld;
    _CrtMemState sNew;
    _CrtMemState sDiff;
    _CrtMemCheckpoint(&sOld);

    //{"lane_merger",LM},{"graph_rock_sample",GRS}
    std::map<std::string, defined_problems> problem_mapping = { {"lane_merger",LM},{"graph_rock_sample",GRS},{"spaceship_repair",SR},{"store_vist",SV} };
    std::map<std::string, interval_solver_type> solver_mapping = { {"Boltzmann Exploration",softmax},{"Confidence Maximization",certainty_max},{"Epsilon Greedy",epsilon_greedy},{"Global Thompson",global_thompson},{"Local Thompson",local_thompson} };


    run_all_experiments(problem_mapping,solver_mapping);

    _CrtMemCheckpoint(&sNew);
    if (_CrtMemDifference(&sDiff, &sOld, &sNew))
    {
        OutputDebugString(L"-----------_CrtMemDumpStatistics ---------");
        _CrtMemDumpStatistics(&sDiff);
        OutputDebugString(L"-----------_CrtMemDumpAllObjectsSince ---------");
        _CrtMemDumpAllObjectsSince(&sOld);
        OutputDebugString(L"-----------_CrtDumpMemoryLeaks ---------");
        _CrtDumpMemoryLeaks();
    }

    return 0;
}
